import torch
from accelerate import Accelerator
from tqdm.auto import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup


class trainer:
    """Minimal training loop for a model whose forward output exposes ``.loss``.

    Device placement, distributed wrapping and mixed precision are delegated to
    ``accelerate.Accelerator``; the learning rate follows a linear
    warmup-then-decay schedule.
    """

    def __init__(self, model, lr, num_warmup_steps, device):
        """Store the model and build its AdamW optimizer.

        Args:
            model: ``torch.nn.Module`` to train. ``train`` calls it as
                ``model(**batch)`` and reads ``.loss`` from the result.
            lr: Peak learning rate, reached at the end of warmup.
            num_warmup_steps: Number of scheduler steps of linear warmup
                before the linear decay to zero.
            device: Stored for backward compatibility only; it is not used,
                because the ``Accelerator`` created in ``train`` decides
                device placement.
        """
        self.model = model
        # torch.optim.AdamW replaces the deprecated transformers.AdamW.
        # eps / weight_decay are pinned to transformers.AdamW's defaults so the
        # optimizer configuration does not silently change with the switch.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0
        )
        self.device = device
        self.num_warmup_steps = num_warmup_steps

    def train(self, loader, epochs):
        """Run ``epochs`` passes over ``loader`` and return the model.

        Args:
            loader: Sized iterable of batches; each batch is unpacked as
                keyword arguments into the model (presumably a
                ``torch.utils.data.DataLoader`` of dicts — confirm with callers).
            epochs: Number of full passes over ``loader``.

        Returns:
            ``self.model`` as returned by ``accelerator.prepare`` — in
            distributed runs this may be a wrapped module (e.g. DDP).

        NOTE(review): every call creates a new ``Accelerator`` and re-prepares
        the model/optimizer that a previous call already prepared; calling
        ``train`` more than once on the same instance is unverified.
        """
        accelerator = Accelerator()

        self.model.train()

        # Sized from the loader *before* prepare(): with accelerate's default
        # (split_batches=False) the prepared scheduler advances once per process
        # on every optimizer step, so this global count is what it expects.
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=len(loader) * epochs,
        )

        self.model, self.optimizer, loader, scheduler = accelerator.prepare(
            self.model, self.optimizer, loader, scheduler
        )
        # Per-process iteration count: prepare() may shard the loader.
        total_steps = len(loader) * epochs

        # Draw the bar only on the local main process so multi-process runs do
        # not print one bar per rank; the context manager closes it even if
        # the loop raises.
        with tqdm(
            total=total_steps,
            desc="Training",
            disable=not accelerator.is_local_main_process,
        ) as progress_bar:
            for _ in range(epochs):
                for batch in loader:
                    with accelerator.autocast():
                        outputs = self.model(**batch)
                        loss = outputs.loss
                    # accelerator.backward applies any loss scaling that the
                    # configured mixed-precision mode requires.
                    accelerator.backward(loss)
                    self.optimizer.step()
                    scheduler.step()
                    self.optimizer.zero_grad()
                    progress_bar.update(1)
                    progress_bar.set_postfix({"loss": loss.item()})

        return self.model
